package main

import (
	"encoding/json"
	"flag"
	"fmt"
	"github.com/glycerine/rbtree"
	"github.com/pkg/errors"
	"gopkg.in/yaml.v2"
	"net"
	"os"
	"runtime/debug"
	"sync"
	"time"
	"unsafe"
)

// HostConfiguration mirrors the YAML configuration file read by initConfig.
// Duration-valued properties are kept as strings here and converted with
// time.ParseDuration (via parseDurationConfig) when kHost is built.
type HostConfiguration struct {
	// Address dialed over TCP for both the host connection and line callbacks.
	ServerAddress       string `yaml:"server-address"`
	// Name sent to the server in HostConnBody / LineConnBody.
	HostName            string `yaml:"host-name"`
	// Key used to generate and check packet signatures.
	SecretKey           string `yaml:"secret-key"`
	HostReadTimeout     string `yaml:"host-read-timeout"`
	HostWriteTimeout    string `yaml:"host-write-timeout"`
	TargetReadTimeout   string `yaml:"target-read-timeout"`
	TargetWriteTimeout  string `yaml:"target-write-timeout"`
	// Maximum accepted age of a signed LineCallBody timestamp.
	SignTimeout         string `yaml:"sign-timeout"`
	// Pause between reconnect attempts in connect.
	ReconnectInterval   string `yaml:"reconnect-interval"`
	// Period of the OptTypeHostSyn packets sent by keepAlive.
	KeepAliveInterval   string `yaml:"keep-alive-interval"`
	// Largest ContentSize accepted from a server packet header.
	MaxPacketBodySize   int    `yaml:"max-packet-body-size"`
	// Size of the copy buffer used by transportConnData.
	TransportBufferSize int    `yaml:"transport-buffer-size"`
}

// Host is the runtime state of the client: the parsed configuration plus the
// current server connection and the set of active lines.
type Host struct {
	ServerAddress       string
	HostName            string
	SecretKey           string
	HostReadTimeout     time.Duration
	HostWriteTimeout    time.Duration
	TargetReadTimeout   time.Duration
	TargetWriteTimeout  time.Duration
	SignTimeout         time.Duration
	ReconnectInterval   time.Duration
	KeepAliveInterval   time.Duration
	MaxPacketBodySize   int
	TransportBufferSize int
	// Active lines ordered by LineNum (see lineCompare); guarded by LineTreeLock.
	Lines               *rbtree.Tree
	LineTreeLock        *sync.Mutex
	// Current control connection to the server; replaced on every reconnect.
	Conn                net.Conn
	// Error or recovered panic value of the current session; releaseHost sets
	// it to "" when the session ended without either.
	Err                 interface{}
	// Header of the packet resolvePacket is currently dispatching.
	PktHeader           *PacketHeader
}

// Line is one proxied connection requested by the server with a LineCall
// packet. It pairs a connection to the target (TargetConn) with a callback
// connection back to the server (HostConn).
type Line struct {
	Host          *Host
	// Identifier from the server's LineCallBody; the key in Host.Lines.
	LineNum       int32
	// Network and address passed to net.Dial to reach the target.
	Network       string
	// Taken from LineCallBody; in this file it is only used in log output.
	ProxyPort     uint16
	TargetAddress string
	// Dialed by handleLine.
	TargetConn    net.Conn
	// Dialed by lineCallback.
	HostConn      net.Conn
	// lineCallback sends whether HostConn was set up; waitLineCallback receives.
	ConnChannel   chan bool
	// Final error or recovered panic value, logged by releaseLine.
	Err           interface{}
}

// ErrWrap holds an error or recovered panic value so that a deferred closure
// can see the value assigned by the surrounding function body.
type ErrWrap struct {
	err interface{}
}

// Process-wide state populated by init (flags) and initConfig.
var (
	// Path of the YAML configuration file (-c flag).
	kConfigPath string
	// Set by the -h flag; initConfig prints usage and exits when true.
	kIsShowHelp bool
	// Configuration as decoded from the YAML file on top of defaults.
	kHostConfig *HostConfiguration
	// Runtime host state built from kHostConfig.
	kHost       *Host
)

// init registers the command-line flags and installs the custom usage text.
// Parsing itself is deferred to initConfig.
func init() {
	flag.Usage = usage
	flag.BoolVar(&kIsShowHelp, "h", false, "Show help")
	flag.StringVar(&kConfigPath, "c", "rainbow-client.yml", "Set the configuration `filename`")
}

func initConfig() {
	var err error

	flag.Parse()
	if kIsShowHelp {
		usage()
		os.Exit(0)
	}

	kHostConfig = &HostConfiguration{
		ServerAddress:       "",
		HostName:            "",
		SecretKey:           "",
		HostReadTimeout:     "3m",
		HostWriteTimeout:    "3m",
		TargetReadTimeout:   "3m",
		TargetWriteTimeout:  "3m",
		SignTimeout:         "30s",
		ReconnectInterval:   "5s",
		KeepAliveInterval:   "1m",
		MaxPacketBodySize:   65536 - PacketHeaderSize,
		TransportBufferSize: 65536,
	}

	file, err := os.Open(kConfigPath)
	if err != nil {
		panic(err)
	}
	defer closeFile(file)

	err = yaml.NewDecoder(file).Decode(&kHostConfig)
	if err != nil {
		panic(err)
	}

	if kHostConfig.ServerAddress == "" {
		panic("The configuration property 'server-address' is required")
	}
	if kHostConfig.HostName == "" {
		panic("The configuration property 'host-name' is required")
	}
	if kHostConfig.SecretKey == "" {
		panic("The configuration property 'secret-key' is required")
	}

	configContent, err := yaml.Marshal(kHostConfig)
	if err != nil {
		panic(err)
	}

	fmt.Println(string(configContent))

	kHost = &Host{
		ServerAddress:       kHostConfig.ServerAddress,
		HostName:            kHostConfig.HostName,
		SecretKey:           kHostConfig.SecretKey,
		HostReadTimeout:     parseDurationConfig("host-read-timeout", kHostConfig.HostReadTimeout),
		HostWriteTimeout:    parseDurationConfig("host-write-timeout", kHostConfig.HostWriteTimeout),
		TargetReadTimeout:   parseDurationConfig("host-read-timeout", kHostConfig.TargetReadTimeout),
		TargetWriteTimeout:  parseDurationConfig("host-write-timeout", kHostConfig.TargetWriteTimeout),
		SignTimeout:         parseDurationConfig("sign-timeout", kHostConfig.SignTimeout),
		ReconnectInterval:   parseDurationConfig("reconnect-interval", kHostConfig.ReconnectInterval),
		KeepAliveInterval:   parseDurationConfig("keep-alive-interval", kHostConfig.KeepAliveInterval),
		MaxPacketBodySize:   kHostConfig.MaxPacketBodySize,
		TransportBufferSize: kHostConfig.TransportBufferSize,
		Lines:               rbtree.NewTree(lineCompare),
		LineTreeLock:        &sync.Mutex{},
	}
}

// usage prints the version banner, the invocation synopsis and every
// registered flag to stderr. It is installed as flag.Usage.
func usage() {
	header := []string{
		"rainbow-client version: 1.0",
		"Usage: rainbow-client [-c filename]",
		"Options: ",
	}
	for _, text := range header {
		_, _ = fmt.Fprintln(os.Stderr, text)
	}
	flag.PrintDefaults()
}

// connect runs the client forever: each handleHost call is one session with
// the server, and the loop pauses for ReconnectInterval before dialing again.
func connect() {
	for {
		handleHost(kHost)
		time.Sleep(kHost.ReconnectInterval)
	}
}

// handleHost runs one session with the server for host. It dials
// host.ServerAddress, sends the OptTypeHostConn handshake, starts the
// keep-alive sender, and then dispatches incoming packets until an error
// ends the session. The outcome is stored in host.Err. Once the dial has
// succeeded, releaseHost closes the connection and logs the result.
func handleHost(host *Host) {
	// Use the host passed in rather than the global kHost so the function
	// works for any Host value.
	host.Conn, host.Err = net.Dial("tcp", host.ServerAddress)
	if host.Err != nil {
		logError(host.Err, fmt.Sprintf(
			"Connect to server '%s' failed", host.ServerAddress))
		return
	}
	defer releaseHost(host)

	if host.Err = hostConn(host); host.Err != nil {
		return
	}

	go keepAlive(host)

	host.Err = resolvePacket(host)
}

// releaseHost is deferred by handleHost after a successful dial. It closes
// host.Conn and, when no error was recorded, converts an in-flight panic into
// host.Err (printing the stack). It always logs how the connection ended;
// host.Err is set to "" when there was neither an error nor a panic.
func releaseHost(host *Host) {
	closeConn(host.Conn)
	if host.Err == nil {
		// recover only works because releaseHost itself is the deferred call.
		host.Err = recover()
		if host.Err != nil {
			debug.PrintStack()
		}
	}
	if host.Err == nil {
		host.Err = ""
	}
	// getConnInfo produces the same "net local->net remote" text and is nil-safe.
	logError(host.Err,
		fmt.Sprintf("Host conn '%s' released", getConnInfo(host.Conn)))
}

// hostConn sends the OptTypeHostConn handshake on host.Conn. The packet is a
// PacketHeader followed by a JSON-encoded HostConnBody that carries the host
// name and a signature generated from host.SecretKey. The write is bounded by
// host.HostWriteTimeout.
func hostConn(host *Host) error {
	body := &HostConnBody{HostName: host.HostName}
	body.GenerateSign(host.SecretKey)

	payload, err := json.Marshal(body)
	if err != nil {
		return err
	}

	header := &PacketHeader{
		Magic:         PacketMagic,
		Version:       PacketVersion,
		OperationType: OptTypeHostConn,
		ContentSize:   int32(len(payload)),
	}

	// Header occupies the first PacketHeaderSize bytes; the body follows.
	packet := make([]byte, PacketHeaderSize+len(payload))
	copy(packet, StructToBytes(unsafe.Pointer(header), PacketHeaderSize))
	copy(packet[PacketHeaderSize:], payload)

	return writeFixedData(host.Conn, packet, host.HostWriteTimeout)
}

// keepAlive sends a header-only OptTypeHostSyn packet on the host connection
// every host.KeepAliveInterval until a write fails. It captures host.Conn when
// it starts, so after a reconnect it keeps targeting the old (now closed)
// connection and exits on its next failed write. Failures and panics are
// logged by the deferred handler.
func keepAlive(host *Host) {
	errWrap := &ErrWrap{}
	conn := host.Conn
	defer func() {
		if errWrap.err == nil {
			errWrap.err = recover()
			if errWrap.err != nil {
				debug.PrintStack()
			}
		}
		if errWrap.err != nil {
			logError(errWrap.err, fmt.Sprintf(
				"Keep alive for host conn '%s' failed", getConnInfo(conn)))
		}
	}()

	// time.NewTicker panics on a non-positive period. A zero interval would
	// also flood the server, so reject it explicitly.
	if host.KeepAliveInterval <= 0 {
		errWrap.err = errors.Errorf("Incorrect keep alive interval '%s'", host.KeepAliveInterval)
		return
	}

	hostSynHeader := &PacketHeader{
		Magic:         PacketMagic,
		Version:       PacketVersion,
		OperationType: OptTypeHostSyn,
		ContentSize:   0,
	}
	hostSynBytes := StructToBytes(unsafe.Pointer(hostSynHeader), PacketHeaderSize)

	ticker := time.NewTicker(host.KeepAliveInterval)
	defer ticker.Stop()
	for range ticker.C {
		errWrap.err = writeFixedData(conn, hostSynBytes, host.HostWriteTimeout)
		if errWrap.err != nil {
			return
		}
	}
}

// resolvePacket reads packet headers from host.Conn in a loop and dispatches
// them by OperationType. It returns the first error, which can be a read
// failure, an invalid header, an unknown operation, or a handler error.
//
// A header is rejected when its magic differs, its version is newer than
// PacketVersion, or its ContentSize is negative or exceeds
// host.MaxPacketBodySize. All of these values come from the network and are
// untrusted.
func resolvePacket(host *Host) error {
	pktHeaderBytes := make([]byte, PacketHeaderSize)
	for {
		if err := readFixedData(host.Conn, pktHeaderBytes, host.HostReadTimeout); err != nil {
			return err
		}

		// NOTE(review): BytesToStruct is defined elsewhere. PktHeader presumably
		// aliases pktHeaderBytes and is only valid until the next header is read.
		host.PktHeader = (*PacketHeader)(BytesToStruct(pktHeaderBytes))

		// A negative size would otherwise pass the max check and make
		// handleLineCall's make() panic.
		if host.PktHeader.Magic != PacketMagic ||
			host.PktHeader.Version > PacketVersion ||
			host.PktHeader.ContentSize < 0 ||
			int(host.PktHeader.ContentSize) > host.MaxPacketBodySize {
			return errors.Errorf("Incorrect packet header:%v", host.PktHeader)
		}

		var err error
		switch host.PktHeader.OperationType {
		case OptTypeHostAck:
			err = handleHostAck(host)
		case OptTypeLineCall:
			err = handleLineCall(host)
		default:
			err = errors.Errorf("Incorrect 'OperationType' from packet header:%v", host.PktHeader)
		}
		if err != nil {
			return err
		}
	}
}

// handleHostAck accepts an OptTypeHostAck packet. An ack has no body, so any
// non-zero ContentSize in the header is reported as an error.
func handleHostAck(host *Host) error {
	if host.PktHeader.ContentSize == 0 {
		return nil
	}
	return errors.Errorf("Incorrect 'ContentSize' from packet header:%v", host.PktHeader)
}

func handleLineCall(host *Host) error {
	var err error
	contentSize := int(host.PktHeader.ContentSize)
	contentBytes := make([]byte, contentSize)

	err = readFixedData(host.Conn, contentBytes, host.HostReadTimeout)
	if err != nil {
		return err
	}

	lineCallBody := &LineCallBody{}
	err = json.Unmarshal(contentBytes, lineCallBody)
	if err != nil {
		return errors.Wrap(err,
			fmt.Sprintf("Parse the json '%s' to 'LineConnBody' failed", contentBytes))
	}

	if !lineCallBody.CheckTimestamp(host.SignTimeout) {
		return errors.New(
			fmt.Sprintf("Incorrect timestamp '%s'", contentBytes))
	}

	if !lineCallBody.CheckSign(host.SecretKey) {
		return errors.New(
			fmt.Sprintf("Incorrect sign '%s'", contentBytes))
	}

	host.LineTreeLock.Lock()
	lineItem := host.Lines.Get(&Line{LineNum: lineCallBody.LineNum})
	if lineItem != nil {
		host.LineTreeLock.Unlock()
		return nil
	}
	line := &Line{
		Host:          host,
		LineNum:       lineCallBody.LineNum,
		Network:       lineCallBody.Network,
		ProxyPort:     lineCallBody.ProxyPort,
		TargetAddress: lineCallBody.TargetAddress,
		ConnChannel:   make(chan bool),
	}
	succeed := host.Lines.Insert(line)
	host.LineTreeLock.Unlock()
	if !succeed {
		return nil
	}
	go handleLine(line)
	return nil
}

// handleLine serves one line. It starts lineCallback to open the callback
// connection to the server and, at the same time, dials the target. Once the
// callback reports success it copies target -> server until either side fails.
// The outcome is stored in line.Err; releaseLine cleans up and logs it.
func handleLine(line *Line) {
	defer releaseLine(line)

	go lineCallback(line)

	line.TargetConn, line.Err = net.Dial(line.Network, line.TargetAddress)
	if line.Err != nil {
		return
	}

	if ok := waitLineCallback(line); !ok {
		line.Err = errors.Errorf("Wait line '%d' of host callback failed", line.LineNum)
		return
	}

	host := line.Host
	line.Err = transportConnData(
		line.TargetConn,
		line.HostConn,
		host.TargetReadTimeout,
		host.HostWriteTimeout,
		host.TransportBufferSize)
}

// lineCallback opens the line's callback connection to the server
// (line.HostConn) and announces it with lineConn. It reports the outcome on
// line.ConnChannel (true on success, false on failure). On success it then
// copies server -> target (line.TargetConn) until either side fails.
//
// Each send completes only if ConnChannel has buffer space or handleLine is
// still waiting in waitLineCallback. Once HostConn is dialed, the deferred
// handler closes it, turns a panic into an error (printing the stack) and
// logs any failure.
func lineCallback(line *Line) {
	errWrap := &ErrWrap{}

	line.HostConn, errWrap.err = net.Dial("tcp", line.Host.ServerAddress)
	if errWrap.err != nil {
		// Log the cause here: handleLine only learns that the callback failed.
		logError(errWrap.err, fmt.Sprintf(
			"Line '%d' connect to server '%s' failed", line.LineNum, line.Host.ServerAddress))
		line.ConnChannel <- false
		return
	}
	defer func() {
		defer closeConn(line.HostConn)
		if errWrap.err == nil {
			errWrap.err = recover()
			if errWrap.err != nil {
				debug.PrintStack()
			}
		}
		if errWrap.err != nil {
			logError(errWrap.err,
				fmt.Sprintf("Host conn '%s' released", getConnInfo(line.HostConn)))
		}
	}()

	errWrap.err = lineConn(line)
	if errWrap.err != nil {
		line.ConnChannel <- false
		return
	}

	line.ConnChannel <- true
	errWrap.err = transportConnData(
		line.HostConn,
		line.TargetConn,
		line.Host.HostReadTimeout,
		line.Host.TargetWriteTimeout,
		line.Host.TransportBufferSize)
}

// waitLineCallback blocks until lineCallback reports on line.ConnChannel or
// line.Host.HostReadTimeout elapses. It returns the reported value, or false on
// timeout.
func waitLineCallback(line *Line) bool {
	timer := time.NewTimer(line.Host.HostReadTimeout)
	defer timer.Stop()

	select {
	case ok := <-line.ConnChannel:
		return ok
	case <-timer.C:
		return false
	}
}

// releaseLine is deferred by handleLine. It closes the target connection and
// removes the line from its host's tree. When no error was recorded, it
// converts an in-flight panic into line.Err (printing the stack), or sets
// line.Err to "". It then logs the line's route and connections.
func releaseLine(line *Line) {
	closeConn(line.TargetConn)

	host := line.Host
	host.LineTreeLock.Lock()
	host.Lines.DeleteWithKey(line)
	host.LineTreeLock.Unlock()

	if line.Err == nil {
		// recover is effective here because releaseLine is the deferred call.
		if line.Err = recover(); line.Err != nil {
			debug.PrintStack()
		} else {
			line.Err = ""
		}
	}

	message := fmt.Sprintf(
		"Line: '%d', route: '%s %d->%s', target conn: '%s', host conn: '%s' released",
		line.LineNum,
		line.Network,
		line.ProxyPort,
		line.TargetAddress,
		getConnInfo(line.TargetConn),
		getConnInfo(line.HostConn))
	logError(line.Err, message)
}

// lineConn sends the OptTypeLineConn packet on line.HostConn, identifying this
// callback connection to the server. The packet is a PacketHeader followed by
// a JSON LineConnBody carrying the host name, line number and a signature
// generated from the host's SecretKey. The write is bounded by HostWriteTimeout.
func lineConn(line *Line) error {
	host := line.Host
	body := &LineConnBody{
		HostName: host.HostName,
		LineNum:  line.LineNum,
	}
	body.GenerateSign(host.SecretKey)

	payload, err := json.Marshal(body)
	if err != nil {
		return err
	}

	header := &PacketHeader{
		Magic:         PacketMagic,
		Version:       PacketVersion,
		OperationType: OptTypeLineConn,
		ContentSize:   int32(len(payload)),
	}

	// Header occupies the first PacketHeaderSize bytes; the body follows.
	packet := make([]byte, PacketHeaderSize+len(payload))
	copy(packet, StructToBytes(unsafe.Pointer(header), PacketHeaderSize))
	copy(packet[PacketHeaderSize:], payload)

	return writeFixedData(line.HostConn, packet, host.HostWriteTimeout)
}

// transportConnData copies data from readConn to writeConn until a read or
// write fails, and returns that error (io.EOF when readConn is closed cleanly).
// Each read is bounded by readTimeout; each chunk is written in full with
// writeFixedData under writeTimeout. bufferSize is the maximum chunk size.
func transportConnData(
	readConn net.Conn,
	writeConn net.Conn,
	readTimeout time.Duration,
	writeTimeout time.Duration,
	bufferSize int) error {
	buffer := make([]byte, bufferSize)
	for {
		if err := readConn.SetReadDeadline(time.Now().Add(readTimeout)); err != nil {
			return err
		}
		readSize, readErr := readConn.Read(buffer)
		// io.Reader may return data together with an error (e.g. the final
		// bytes plus io.EOF). Forward that data before reporting the error.
		if readSize > 0 {
			if err := writeFixedData(writeConn, buffer[:readSize], writeTimeout); err != nil {
				return err
			}
		}
		if readErr != nil {
			return readErr
		}
	}
}

// readFixedData fills bytes completely from conn. Each Read gets a fresh
// deadline of timeout. A nil or empty slice returns nil without touching conn.
//
// As with io.ReadFull, an error reported together with the bytes that complete
// the buffer (e.g. io.EOF on the last Read) is not treated as a failure. An
// error before the buffer is full is returned as-is.
func readFixedData(conn net.Conn, bytes []byte, timeout time.Duration) error {
	for filled := 0; filled < len(bytes); {
		if err := conn.SetReadDeadline(time.Now().Add(timeout)); err != nil {
			return err
		}
		n, err := conn.Read(bytes[filled:])
		filled += n
		if err != nil && filled < len(bytes) {
			return err
		}
	}
	return nil
}

// writeFixedData writes all of bytes to conn. Each Write gets a fresh deadline
// of timeout. A nil or empty slice returns nil without touching conn. The
// first SetWriteDeadline or Write error is returned unchanged.
func writeFixedData(conn net.Conn, bytes []byte, timeout time.Duration) error {
	for written := 0; written < len(bytes); {
		if err := conn.SetWriteDeadline(time.Now().Add(timeout)); err != nil {
			return err
		}
		n, err := conn.Write(bytes[written:])
		if err != nil {
			return err
		}
		written += n
	}
	return nil
}

// getConnInfo describes conn as "<net> <local>-><net> <remote>" for log
// messages. It returns "__unconnected__" for a nil connection.
func getConnInfo(conn net.Conn) string {
	if conn != nil {
		local, remote := conn.LocalAddr(), conn.RemoteAddr()
		return local.Network() + " " + local.String() +
			"->" +
			remote.Network() + " " + remote.String()
	}
	return "__unconnected__"
}

// parseDurationConfig parses the configuration property name, whose raw text
// is value, with time.ParseDuration. It panics with a wrapped error that names
// the property when value is not a valid duration.
func parseDurationConfig(name string, value string) time.Duration {
	parsed, parseErr := time.ParseDuration(value)
	if parseErr == nil {
		return parsed
	}
	panic(errors.Wrap(parseErr, fmt.Sprintf("Incorrect duration property '%s': %s", name, value)))
}

// closeFile closes file as a best-effort cleanup. A close error is logged, not
// returned.
func closeFile(file *os.File) {
	if closeErr := file.Close(); closeErr != nil {
		logError(closeErr, fmt.Sprintf("Close file '%s' failed", file.Name()))
	}
}

// closeConn closes conn as a best-effort cleanup; a nil conn is ignored. A
// close error is logged together with the connection's addresses so the log
// line identifies which connection failed.
func closeConn(conn net.Conn) {
	if conn == nil {
		return
	}
	// Capture the addresses before closing, while they are guaranteed valid.
	info := getConnInfo(conn)
	if err := conn.Close(); err != nil {
		logError(err, fmt.Sprintf("Close conn '%s' failed", info))
	}
}

// lineCompare orders *Line items in the rbtree by LineNum. It returns -1, 0 or
// 1 for less, equal or greater. Both items must be *Line.
func lineCompare(a, b rbtree.Item) int {
	left, right := a.(*Line).LineNum, b.(*Line).LineNum
	switch {
	case left < right:
		return -1
	case left > right:
		return 1
	default:
		return 0
	}
}

// logError writes one line to stderr of the form
// "[<RFC 3339 timestamp>] <message>: <err>".
//
// err may be any value (an error, a recovered panic value, or "" after a clean
// release); it is printed with fmt's default formatting.
func logError(err interface{}, message string) {
	// time.RFC3339Nano uses the 24-hour clock ("15"). The hand-written layout
	// used "03", a 12-hour clock with no AM/PM marker, so afternoon timestamps
	// were indistinguishable from morning ones.
	_, _ = fmt.Fprintln(
		os.Stderr,
		fmt.Sprintf("[%s] %s:", time.Now().Format(time.RFC3339Nano), message),
		err)
}

// main loads the configuration and then enters connect's endless reconnect
// loop, so it never returns normally.
func main() {
	initConfig()
	connect()
}
